package com.liusu.deeplearning4j;

import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;

/**
 * Factory helpers that build DL4J {@link DataSetIterator}s over the MNIST train and test splits.
 *
 * <p>The no-argument methods use {@link #DEFAULT_BATCH_SIZE} and {@link #DEFAULT_SEED}. The
 * overloads let callers choose their own batch size and seed.
 */
public class MnistLoader {
    /** Mini-batch size used by the no-argument loaders. */
    private static final int DEFAULT_BATCH_SIZE = 64;

    /**
     * Seed passed to {@code MnistDataSetIterator} by the no-argument loaders. Both splits share it.
     * NOTE(review): presumably this controls example shuffling order; confirm against the DL4J docs.
     */
    private static final int DEFAULT_SEED = 12345;

    /**
     * Returns an iterator over the MNIST training split, using the default batch size and seed.
     *
     * @return iterator over the training split
     * @throws Exception if the underlying {@code MnistDataSetIterator} cannot be created
     */
    public static DataSetIterator getMnistTrainData() throws Exception {
        return getMnistTrainData(DEFAULT_BATCH_SIZE, DEFAULT_SEED);
    }

    /**
     * Returns an iterator over the MNIST training split.
     *
     * @param batchSize number of examples per mini-batch; must be positive
     * @param seed seed forwarded to {@code MnistDataSetIterator}
     * @return iterator over the training split
     * @throws IllegalArgumentException if {@code batchSize <= 0}
     * @throws Exception if the underlying {@code MnistDataSetIterator} cannot be created
     */
    public static DataSetIterator getMnistTrainData(int batchSize, int seed) throws Exception {
        return newIterator(batchSize, true, seed);
    }

    /**
     * Returns an iterator over the MNIST test split, using the default batch size and seed.
     *
     * @return iterator over the test split
     * @throws Exception if the underlying {@code MnistDataSetIterator} cannot be created
     */
    public static DataSetIterator getMnistTestData() throws Exception {
        return getMnistTestData(DEFAULT_BATCH_SIZE, DEFAULT_SEED);
    }

    /**
     * Returns an iterator over the MNIST test split.
     *
     * @param batchSize number of examples per mini-batch; must be positive
     * @param seed seed forwarded to {@code MnistDataSetIterator}
     * @return iterator over the test split
     * @throws IllegalArgumentException if {@code batchSize <= 0}
     * @throws Exception if the underlying {@code MnistDataSetIterator} cannot be created
     */
    public static DataSetIterator getMnistTestData(int batchSize, int seed) throws Exception {
        return newIterator(batchSize, false, seed);
    }

    /** Single construction point for both splits; {@code train} selects the train or test split. */
    private static DataSetIterator newIterator(int batchSize, boolean train, int seed)
            throws Exception {
        // Fail fast with a clear message instead of relying on whatever the library does.
        if (batchSize <= 0) {
            throw new IllegalArgumentException("batchSize must be positive, got " + batchSize);
        }
        return new MnistDataSetIterator(batchSize, train, seed);
    }
}
